Mistral
Paper: Mistral 7B (Oct 2023). Code: GitHub: mistralai/mistral-src
Transformer vs Mistral vs Llama
- Differences between the vanilla Transformer and Mistral:
  - Transformer: encoder-decoder model.
  - Mistral: decoder-only model, similar to Llama.
- Differences from Llama:
  - Self-attention:
    - Uses sliding window attention (SWA). Not present in Llama.
    - Uses a rolling buffer KV cache for inference. Not present in Llama.
    - Uses grouped-query attention (GQA), which Llama also uses (in Llama 2, for the 34B and 70B models).
  - Feed-forward layer:
    - Uses a SiLU-gated feed-forward layer (SwiGLU) instead of the ReLU of the original Transformer. Llama uses the same SwiGLU design, so this is not a difference from Llama.
    - Mixture of experts (MoE) in the FFN (Mixtral 8x7B). Not present in Llama.
    - Mixtral 8x7B: eight feed-forward networks run in parallel as experts; a router sends each token to two of them and mixes their outputs.
- Model structure:
  - Input tokens are converted into embeddings.
  - The Transformer block is repeated 32 times in Mistral 7B.
  - The output of each block is fed to the next block as input.
  - The output of the last block passes through RMSNorm, a linear layer, and softmax to produce the model output (probabilities over the vocabulary).
- Similarity to other Transformer models:
  - Built from blocks consisting of multi-head attention, normalization, feed-forward, and another normalization.
  - Such a block is called a Transformer block, encoder block, or decoder block, depending on where it is used.
  - Normalization comes before the self-attention and feed-forward layers (pre-norm), whereas the original Transformer applies normalization after these layers (post-norm).
  - Uses residual connections around each sublayer.
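To make the block structure concrete, here is a minimal pure-Python sketch of one pre-norm decoder block: RMSNorm before each sublayer, a residual addition after it, and a SiLU-gated (SwiGLU) feed-forward. It also sketches a Mixtral-style FFN that routes each token to the top 2 of 8 experts. All sizes and random weights are hypothetical toy values, the attention sublayer is replaced by a placeholder, and the helper names (`rms_norm`, `swiglu_ffn`, `moe_ffn`, `pre_norm_block`) are illustrative, not taken from mistral-src.

```python
import math
import random


def matvec(m, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, x)) for row in m]


def rms_norm(x, weight, eps=1e-5):
    """RMSNorm: divide by the root mean square of x (no mean subtraction, no bias), then scale."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v * inv_rms for v, w in zip(x, weight)]


def silu(v):
    """SiLU (swish): v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


def swiglu_ffn(x, w1, w3, w2):
    """SiLU-gated FFN: w2 @ (silu(w1 @ x) * (w3 @ x))."""
    gate = [silu(v) for v in matvec(w1, x)]
    up = matvec(w3, x)
    return matvec(w2, [g * u for g, u in zip(gate, up)])


def moe_ffn(x, router, experts, top_k=2):
    """Send x to the top_k experts by router logit; mix outputs with a softmax over those logits."""
    logits = matvec(router, x)  # one logit per expert
    top = sorted(range(len(logits)), key=lambda e: logits[e], reverse=True)[:top_k]
    peak = max(logits[e] for e in top)
    weights = [math.exp(logits[e] - peak) for e in top]  # numerically stable softmax
    total = sum(weights)
    out = [0.0] * len(x)
    for e, w in zip(top, weights):
        y = swiglu_ffn(x, *experts[e])  # only the selected experts are evaluated
        out = [o + (w / total) * v for o, v in zip(out, y)]
    return out


def pre_norm_block(x, attn, ffn, norm1_w, norm2_w):
    """Pre-norm wiring: normalize before each sublayer, then add the residual."""
    h = [a + b for a, b in zip(x, attn(rms_norm(x, norm1_w)))]
    return [a + b for a, b in zip(h, ffn(rms_norm(h, norm2_w)))]


# Toy sizes and random weights (hypothetical; Mistral uses dim=4096).
rng = random.Random(0)
dim, hidden, n_experts = 4, 8, 8


def rand_matrix(rows, cols):
    return [[rng.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]


# Each expert is a (w1, w3, w2) triple for swiglu_ffn.
experts = [(rand_matrix(hidden, dim), rand_matrix(hidden, dim), rand_matrix(dim, hidden))
           for _ in range(n_experts)]
router = rand_matrix(n_experts, dim)
ones = [1.0] * dim


def placeholder_attn(v):
    """Stand-in for self-attention, only so the block can run."""
    return v


x = [1.0, -2.0, 3.0, 0.5]
dense = pre_norm_block(x, placeholder_attn, lambda v: swiglu_ffn(v, *experts[0]), ones, ones)
sparse = pre_norm_block(x, placeholder_attn, lambda v: moe_ffn(v, router, experts), ones, ones)
print(dense)   # Mistral 7B style: a single FFN
print(sparse)  # Mixtral style: top 2 of 8 experts
```

Swapping `swiglu_ffn` for `moe_ffn` changes only the feed-forward sublayer; the pre-norm and residual wiring stay the same. Per token, only two experts are evaluated rather than all eight, which is what keeps the compute per token far below the total parameter count.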
Note
What is a decoder-only model?
- It looks like a Transformer with the encoder side removed.
- It also has no cross-attention, since there is no encoder output to attend to.
- Mistral and Llama are decoder-only models.
- An encoder-only model, by contrast, resembles the encoder side of the Transformer and lacks the final linear and softmax layers. BERT is an example of an encoder-only model.
Mistral Architecture
- Embedding dimension (dim) = 4096
- Number of Transformer blocks = 32
- Grouped-query attention:
  - Number of query (Q) heads = 32, each with 128 dimensions (32 × 128 = 4096).
  - Number of key (K) and value (V) heads = 8.
  - Query heads are grouped 4 at a time (32 / 8), and each group shares one K/V head.
- Sliding window size = 4096 (7B model); not listed for the Mixtral 8x7B model.
- Context length = 8192 (7B model), 32768 (Mixtral 8x7B model).
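The head grouping and its effect on KV cache size can be checked with a little arithmetic. This sketch assumes contiguous grouping (query heads 0 to 3 share K/V head 0, heads 4 to 7 share K/V head 1, and so on) and a 16-bit (2-byte) cache dtype; both are assumptions for illustration, and the function names are hypothetical.

```python
# Mistral 7B attention hyperparameters (from the list above).
dim, n_layers = 4096, 32
n_heads, n_kv_heads, head_dim = 32, 8, 128
assert n_heads * head_dim == dim

group_size = n_heads // n_kv_heads  # 4 query heads per K/V head


def kv_head_for(q_head: int) -> int:
    """K/V head used by a query head, assuming contiguous groups (0-3 -> 0, 4-7 -> 1, ...)."""
    return q_head // group_size


def kv_cache_bytes_per_token(kv_heads: int, bytes_per_value: int = 2) -> int:
    """K and V bytes stored per token across all layers (bytes_per_value=2 assumes fp16/bf16)."""
    return 2 * n_layers * kv_heads * head_dim * bytes_per_value  # leading 2 = one K plus one V


print([kv_head_for(h) for h in range(8)])    # [0, 0, 0, 0, 1, 1, 1, 1]
print(kv_cache_bytes_per_token(n_kv_heads))  # 131072 bytes (128 KiB) with GQA
print(kv_cache_bytes_per_token(n_heads))     # 524288 bytes (512 KiB) if every Q head had its own K/V
print(4096 * kv_cache_bytes_per_token(n_kv_heads) / 2**20)  # 512.0 MiB for one full 4096-token window
```

Under these assumptions, GQA shrinks the per-token KV cache by the group size (4x here) while keeping all 32 query heads.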
Sliding Window Attention (SWA) vs Self-Attention
- In (causal) self-attention, each token can attend to every earlier token in the sequence.
- In sliding window attention, each token can attend only to tokens within a fixed-size window.
- Sliding window attention is more efficient for long sequences: for sequence length n and window size W, attention cost grows as O(n·W) instead of the O(n²) of full self-attention.
- Sliding window attention captures local context efficiently by focusing on the fixed-size window preceding each token.
- In SWA, the hidden state at position i in layer k attends to the hidden states of the preceding layer at positions i - W to i. Because layers are stacked, information can still flow from tokens up to W × k positions away.
- With a window size of W = 4096 and 32 layers, SWA has a theoretical attention span of about 4096 × 32 = 131,072 ≈ 131K tokens.
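A small pure-Python sketch of the sliding window mask and of how stacked layers widen the reach. It follows the description above, letting position i see positions i - W to i; real attention kernels may count the window boundary slightly differently. The function names are hypothetical.

```python
def sliding_window_mask(n: int, window: int) -> list[list[bool]]:
    """mask[i][j] is True when query position i may attend to key position j.

    Causal sliding window: position i sees positions i - window .. i (never the future).
    """
    return [[i - window <= j <= i for j in range(n)] for i in range(n)]


def reachable_inputs(mask: list[list[bool]], layers: int) -> list[set[int]]:
    """For each position, the input positions that can influence it after `layers` stacked layers."""
    n = len(mask)
    reach = [{i} for i in range(n)]  # before any attention layer, each position only holds itself
    for _ in range(layers):
        reach = [set().union(*(reach[j] for j in range(n) if mask[i][j])) for i in range(n)]
    return reach


for row in sliding_window_mask(6, 2):
    print("".join("1" if allowed else "." for allowed in row))
# 1.....
# 11....
# 111...
# .111..
# ..111.
# ...111

reach = reachable_inputs(sliding_window_mask(10, 2), layers=3)
print(min(reach[9]))  # 3, i.e. position 9 reaches back W * k = 2 * 3 = 6 tokens
print(4096 * 32)      # 131072: theoretical span with W = 4096 and 32 layers
```

Each layer only looks W positions back, but composing k layers chains those windows, which is where the W × k span comes from.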
KV Cache
For background, first refer to the KV Cache notes.
- Rolling Buffer KV Cache
- Pre-fill and chunking
Rolling Buffer KV Cache
In SWA, the KV cache is implemented as a rolling buffer of fixed size W. It stores the key and value vectors of only the last W tokens, because older tokens fall outside every query's window. Per the Mistral paper, the keys and values for timestep i are stored at position i mod W of the cache, so the cache never grows beyond W entries, however long the sequence is.
For example, if the sliding window is 10 tokens, the KV cache stores the key and value vectors for the last 10 tokens in the sequence.
When the 11th token is processed, its key and value vectors are written to the cache and those of the 1st token are evicted.
This is done in place, without shifting any entries: the 11th token overwrites the slot of the 1st token, so the cache holds the 11th, 2nd, 3rd, 4th, 5th, 6th, 7th, 8th, 9th, and 10th tokens in that slot order.
A pointer (the write position, seqlen mod W) tracks where the next token goes. To read the cache in chronological order, it is "unrotated" around that pointer:
import torch

def unrotate(cache: torch.Tensor, seqlen: int) -> torch.Tensor:
    """Return the rolling-buffer cache in chronological order (oldest entry first)."""
    assert cache.ndim == 3  # (W, H, D): window, heads, head dim
    position = seqlen % cache.shape[0]  # Next slot to write; once the cache is full, it holds the oldest token.
    if seqlen < cache.shape[0]:
        return cache[:seqlen]  # Cache is not full yet: ignore the empty slots.
    elif position == 0:
        return cache  # Cache is full and the oldest token is in slot 0, so it is already in order.
    else:
        return torch.cat((cache[position:], cache[:position]), dim=0)  # Move the oldest entries (from `position` on) to the front.
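To see the pointer bookkeeping without tensors, here is a pure-Python sketch of the same idea using the 10-token example. `RollingBufferCache` is a hypothetical toy class, not the mistral-src cache, and it stores token labels in place of real key/value vectors.

```python
class RollingBufferCache:
    """Toy rolling buffer: the token at (0-based) step t is written to slot t % window."""

    def __init__(self, window: int):
        self.window = window
        self.slots = [None] * window
        self.seqlen = 0  # total number of tokens written so far

    def append(self, kv) -> None:
        self.slots[self.seqlen % self.window] = kv  # once full, this overwrites the oldest entry
        self.seqlen += 1

    def unrotate(self) -> list:
        """Return cached entries in chronological order (oldest first)."""
        position = self.seqlen % self.window  # next slot to write = oldest entry when full
        if self.seqlen < self.window:
            return self.slots[:self.seqlen]  # not full yet: skip the empty slots
        return self.slots[position:] + self.slots[:position]


cache = RollingBufferCache(window=10)
for token in range(1, 12):  # tokens 1..11
    cache.append(token)
print(cache.slots)       # [11, 2, 3, 4, 5, 6, 7, 8, 9, 10]
print(cache.unrotate())  # [2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
```

Writes cost O(1) and never move existing entries; the reordering is paid only when the cache is read back in order.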
Pre-fill and Chunking
When generating a sequence, the model predicts tokens one at a time, since each token depends on the previous ones. The prompt, however, is known in advance, so the cache can be pre-filled with the prompt's keys and values. If the prompt is too long, it is split into chunks and the cache is filled chunk by chunk; the Mistral paper suggests using the window size as the chunk size. For each chunk, attention is computed over both the cache and the chunk itself, so the attention mask covers both.
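A minimal sketch of the chunking schedule, assuming chunk size = window size (a toy W = 4) and a 10-token prompt. For each chunk it reports how many cached tokens are still visible and the total number of keys attention must cover; that total mirrors the `s + cached_s.clamp(max=self.sliding_window)` expression in the Mistral code. `prefill_schedule` is a hypothetical helper for illustration.

```python
def prefill_schedule(prompt_len: int, chunk_size: int, window: int):
    """Yield (chunk_start, chunk_len, cached_len, kv_len) for each prefill chunk."""
    for start in range(0, prompt_len, chunk_size):
        chunk_len = min(chunk_size, prompt_len - start)
        cached_len = min(start, window)  # the rolling buffer keeps at most `window` past tokens
        yield start, chunk_len, cached_len, chunk_len + cached_len


for row in prefill_schedule(prompt_len=10, chunk_size=4, window=4):
    print(row)
# (0, 4, 0, 4)
# (4, 4, 4, 8)
# (8, 2, 4, 6)
```

The first chunk sees only itself; every later chunk sees itself plus at most W cached tokens, so memory per chunk stays bounded regardless of prompt length.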
# From the Mistral codebase (mistral-src): excerpt from the rotating KV cache's input-metadata method.
# `self`, `seqpos`, `seqlens`, `positions`, `cached_elements`, etc. are computed earlier in that method;
# the mask classes come from xformers.ops.fmha.attn_bias.
# First chunk: the KV cache is empty.
if first_prefill:
    assert all([pos == 0 for pos in seqpos]), (seqpos)
    # Q size = KV size = incoming sequence lengths (seqlens); no cached tokens yet
    mask = BlockDiagonalCausalMask.from_seqlens(seqlens).make_local_attention(self.sliding_window)
# Subsequent chunks: the KV cache is not empty.
elif subsequent_prefill:
    # Q size = seqlens = incoming sequence length
    # KV size = incoming sequence length + cached tokens (capped at the sliding window)
    mask = BlockDiagonalMask.from_seqlens(
        q_seqlen=seqlens,
        kv_seqlen=[s + cached_s.clamp(max=self.sliding_window).item() for (s, cached_s) in zip(seqlens, self.kv_seqlens)]
    ).make_local_attention_from_bottomright(self.sliding_window)
# Token generation
else:
    # Q size = seqlens = incoming sequence length
    # KV size = kv_seqlens + cached_elements, capped at the sliding window
    mask = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
        q_seqlen=seqlens,
        kv_padding=self.sliding_window,
        kv_seqlen=(self.kv_seqlens + cached_elements).clamp(max=self.sliding_window).tolist()
    )
return RotatingCacheInputMetadata(
    positions=positions,
    to_cache_mask=to_cache_mask,
    cached_elements=cached_elements,
    cache_positions=cache_positions[to_cache_mask],
    prefill=first_prefill or subsequent_prefill,
    mask=mask,
    seqlens=seqlens,
)
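A pure-Python sketch of what such masks look like for a toy window W = 4, with columns for the cached keys followed by the chunk's own keys. It uses the SWA description of position i seeing positions i - W to i; the xformers masks in the actual code may count the window boundary differently, and `chunk_mask` is a hypothetical helper, not part of mistral-src.

```python
def chunk_mask(start: int, chunk_len: int, window: int) -> list[list[bool]]:
    """Mask for one prefill chunk; columns are [cached keys | chunk keys].

    Queries are absolute positions start .. start + chunk_len - 1. The cache contributes
    at most `window` earlier positions. Query q may see key k when q - window <= k <= q.
    """
    cached = min(start, window)
    key_positions = list(range(start - cached, start + chunk_len))
    return [[q - window <= k <= q for k in key_positions]
            for q in range(start, start + chunk_len)]


def show(mask: list[list[bool]]) -> None:
    for row in mask:
        print("".join("1" if allowed else "." for allowed in row))


show(chunk_mask(start=0, chunk_len=4, window=4))  # first chunk: plain causal mask, empty cache
# 1...
# 11..
# 111.
# 1111
print()
show(chunk_mask(start=4, chunk_len=4, window=4))  # second chunk: 4 cached keys + 4 chunk keys
# 11111...
# .11111..
# ..11111.
# ...11111
```

The first chunk reduces to an ordinary causal mask, while later chunks get a band anchored at the bottom-right corner of the [cache | chunk] key matrix, which is the shape the `make_local_attention_from_bottomright` call is named after.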



